# coding = utf-8

import sys

import click
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib2 import Path
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm

import utils.checkpoint as cp
from dataset import KiTS19
from dataset.transform import MedicalTransform,CropTransform
from loss import GeneralizedDiceLoss
from loss.util import class2one_hot
from network import UNet
from utils.metrics import Evaluator
from utils.vis import imshow
import matplotlib.pyplot as plt

def kits19_test(data_root="/datasets/3Dircadb/chengkun_only_liver", batch_size=1, num_workers=1):
    """Smoke-test the KiTS19 training pipeline by printing each batch's image shape.

    Builds a ``KiTS19`` dataset whose training split uses a ``CropTransform``,
    wraps the training split in a randomly sampled ``DataLoader``, iterates it
    once, and prints ``data["image"].shape`` for every batch.

    Args:
        data_root: Dataset directory passed to ``KiTS19``. Defaults to the path
            that was previously hard-coded here.
        batch_size: ``DataLoader`` batch size (default 1, as before).
        num_workers: Number of ``DataLoader`` worker processes (default 1, as before).
    """
    transform = CropTransform(output_size=(320, 320), roi_error_range=15, use_roi=False)

    # ROI is disabled (use_roi=False, roi_file=None) and the validation split
    # gets no transform (valid_transform=None).
    # NOTE(review): roi_error_range is 15 for the transform but 5 for the
    # dataset; with use_roi=False this presumably has no effect — confirm.
    dataset = KiTS19(data_root, stack_num=5, spec_classes=[0, 1, 2], img_size=(320, 320),
                     use_roi=False, roi_file=None, roi_error_range=5,
                     train_transform=transform, valid_transform=None)

    sampler = RandomSampler(dataset.train_dataset)
    train_loader = DataLoader(dataset.train_dataset, batch_size=batch_size, sampler=sampler,
                              num_workers=num_workers, pin_memory=True)

    # The batch index is not needed; only each batch's image shape is reported.
    for data in train_loader:
        print(data["image"].shape)




if __name__ == "__main__":
    # Script entry point: run the data-pipeline smoke test when executed directly.
    kits19_test()